Fix preserve_rng_state for activation checkpointing #4690

Merged
merged 6 commits into pytorch:master on Mar 20, 2024

Conversation

YangFei1990 (Contributor)

The activation checkpointing implementation has a preserve_rng_state option. When it is set to True, activation checkpointing should use the same RNG state for the two forward runs in a single step. Consider the following test script with activation checkpointing and a dropout op in the model:

import torch
import torch.utils.checkpoint
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--xla", type=int, required=True)
args = parser.parse_args()

if args.xla:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
    from torch_xla.distributed.fsdp import checkpoint_module

to_save = []
def save_output(output):
    to_save.append(output.detach().cpu())

class Model(torch.nn.Module):
    def __init__(self, args):
        super().__init__()
        self.x = torch.nn.Linear(128,128)
        self.dropout = torch.nn.Dropout(p=0.1)
        self.args = args
    
    def forward(self, inp):
        x = self.x(inp)
        output = self.dropout(x)
        if self.args.xla:
            xm.add_step_closure(save_output, args=(output,), run_async=False)
        else:
            save_output(output)
        return output

def main(args):
    if args.xla:
        device = xm.xla_device()
    else:
        device = 0
        torch.cuda.set_device(device)

    model = Model(args)
    model.to(device)

    if args.xla:
        model = checkpoint_module(model)

    _input = torch.randn(128, 128, requires_grad=True)
    _input = _input.to(device)

    if not args.xla:
        output = torch.utils.checkpoint.checkpoint(model, _input)
    else:
        output = model(_input)
    output = torch.sum(output)
    output.backward()
    if args.xla:
        xm.mark_step()
    same_output = torch.allclose(to_save[0], to_save[1])
    print(f"xla {args.xla} same_output {same_output}")

if __name__ == "__main__":    
    main(args)

If everything works correctly, same_output should be True. Without XLA we observe the expected result:

python test_dropout_simple.py --xla 0
xla 0 same_output True

But with XLA the outputs do not match:

python test_dropout_simple.py --xla 1
xla 1 same_output False

This PR fixes the issue by also saving/loading the XLA RNG state in the activation checkpointing implementation. After the fix, the output matches between the two forward runs.
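
For reference, here is a minimal sketch of the idea, assuming the xm.get_rng_state()/xm.set_rng_state() helpers from torch_xla.core.xla_model; the helper name and wiring below are illustrative, not the literal diff (the PR itself adds an xm.fork_rng() context used inside the checkpoint backward):

import contextlib
import torch_xla.core.xla_model as xm

@contextlib.contextmanager
def fork_xla_rng(state_to_replay=None):
    # Remember the caller's XLA RNG state so it can be put back afterwards.
    outer_state = xm.get_rng_state()
    if state_to_replay is not None:
        # Replay the state captured during the original forward pass.
        xm.set_rng_state(state_to_replay)
    try:
        yield
    finally:
        xm.set_rng_state(outer_state)

# Forward pass:       fwd_state = xm.get_rng_state()
# Backward recompute: with fork_xla_rng(fwd_state): outputs = run_function(*inputs)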

@JackCaoG JackCaoG self-requested a review February 27, 2023 18:37
@JackCaoG (Collaborator)

Thanks! Mostly LGTM. Can you add a test case, maybe to https://github.com/pytorch/xla/blob/master/test/test_operations.py? You can compare the result on the xla device with the cpu device. This way we won't regress this.

torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
outputs = ctx.run_function(*detached_inputs)
with xm.fork_rng():
Collaborator

any reason not to pass the rng_devices and ctx.preserve_rng_state ?

Collaborator

It looks like the upstream code doesn't reset the state. @YangFei1990 Do you know why?

Collaborator

I guess the upstream seed is handled by torch.random.fork_rng? Though I am not sure why it doesn't work with pytorch/xla...

Contributor (Author)

Yeah, the upstream seed is handled by torch.random.fork_rng. It forks the torch seed, but somehow it won't set XLA's RNG. Is the seed from torch_xla._XLAC._xla_get_rng_seed(str(device)) independent of the torch seed? How does torch XLA handle RNGs in general?
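
For context, a small illustration of the gap being discussed (plain PyTorch, nothing XLA-specific is saved or restored):

import torch

# torch.random.fork_rng() saves and restores the CPU (and CUDA) generator
# state, but it knows nothing about the XLA per-device seed, so an XLA
# dropout still sees a different RNG state on the recomputed forward.
with torch.random.fork_rng():
    torch.manual_seed(0)  # changes inside the block are undone on exit
# The CPU RNG state is back to what it was here; the XLA seed was never touched.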

Contributor (Author)

I did not change the previous behavior, i.e. the upstream seed is still maintained as it was (check the code below). I simply added preservation of the XLA RNG state on top.

output = torch.sum(output)
output.backward()
xm.mark_step()
same_output = torch.allclose(model.to_save[0], model.to_save[1])
Collaborator

What's this to_save?

Contributor (Author)

to_save is the container that holds the output tensor. With activation checkpointing the forward runs twice, so this container captures both tensors. Check line 2352.

same_output = torch.allclose(model.to_save[0], model.to_save[1])
if not same_output:
    print(f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}")
self.assertTrue(same_output)
Collaborator

I think you can do something similar to

self.assertTrue(same_output, f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}")

Contributor (Author)

Awesome, I didn't know you could do that. Updating.

@alanwaketan (Collaborator) left a comment

Mostly LGTM. Please address the comments.

@JackCaoG JackCaoG merged commit 66acfeb into pytorch:master Mar 20, 2024
17 checks passed
@JackCaoG (Collaborator)

I will take care of the backport

JackCaoG added a commit that referenced this pull request Mar 22, 2024
…6788)

Co-authored-by: Fei <33940270+YangFei1990@users.noreply.github.com>